import os
import pickle
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional, Union, List

import torch
from datasets import load_dataset
from torch import nn
from tqdm import tqdm
from transformers import AutoConfig
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    """
    Command-line options for the MLP-projection experiment.

    Covers which pretrained model/config to load and where to cache it, which
    transformer layer's MLP to project, and the SGD hyper-parameters used to
    fit the projection matrices. Parsed with ``HfArgumentParser``.
    """

    # Required: hub id or local path of the pretrained checkpoint.
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"),
    )
    # Optional separate config; falls back to model_name_or_path when unset.
    config_name: Optional[str] = field(
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
        default=None,
    )
    # Download/cache directory for the pretrained config and weights.
    cache_dir: Optional[str] = field(
        metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"),
        default=None,
    )

    # Device string handed to model.to(...).
    device: Optional[str] = field(
        metadata=dict(help="cuda/cpu"),
        default="cuda",
    )

    # Index into model.transformer.h of the block whose MLP is projected.
    layer_id: Optional[int] = field(
        metadata=dict(help="which layer to project!"),
        default=0,
    )

    # Number of optimizer steps.
    gradient_steps: Optional[int] = field(
        metadata=dict(help="How many steps to run gradient descent for!"),
        default=1000,
    )

    # SGD learning rate for the projection parameters.
    learning_rate: Optional[float] = field(
        metadata=dict(help="Learning rate for projection!"),
        default=1e-4,
    )

    # Sentences per training / evaluation batch.
    batch_size: Optional[int] = field(
        metadata=dict(help="Batch size for training!"),
        default=16,
    )
        
        
# Parse command-line options into a single ModelArguments instance.
parser = HfArgumentParser((ModelArguments,))
(model_args,) = parser.parse_args_into_dataclasses()

device = model_args.device
model_config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)

# The rest of the script reaches into GPT2LMHeadModel internals
# (model.transformer.h[i].ln_1 / attn / ln_2 / mlp), so only GPT-2 style
# checkpoints are supported. The check is a plain substring test on the name.
if 'gpt' in model_args.model_name_or_path:
    model_fn = GPT2LMHeadModel
else:
    raise NotImplementedError(
        f"Only GPT-2 style models are supported, got {model_args.model_name_or_path!r}"
    )

model = model_fn.from_pretrained(
    model_args.model_name_or_path,
    config=model_config,
    cache_dir=model_args.cache_dir,
)
model.to(device)
model.eval()

# Penn Treebank (text only); each row carries a 'sentence' string.
dataset = load_dataset("ptb_text_only", cache_dir='data', download_mode='reuse_cache_if_exists')
test_data = dataset['test']
train_data = dataset['train']
valid_data = dataset['validation']

# NOTE(review): the tokenizer is always the stock "gpt2" one (cached under
# "../..") regardless of --model_name_or_path / --cache_dir. GPT-2 variants
# normally share this vocabulary; confirm for custom checkpoints.
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir="../..")
# Reuse EOS as the padding token so sentences can be batched with padding.
tokenizer.pad_token = tokenizer.eos_token
pad_token_id = tokenizer.pad_token_id


# Put every input tensor on whichever device the model's weights actually live on.
device = next(model.parameters()).device

# Maximum number of positions from the model config; tokenized batches are
# truncated to this length.
max_seq_length = model_config.n_positions

# Raw sentence strings for each PTB split, in dataset order.
test_sequences = [row['sentence'] for row in test_data]
valid_sequences = [row['sentence'] for row in valid_data]
train_sequences = [row['sentence'] for row in train_data]
    

# Only the four projection matrices below are trained. Freeze every pretrained
# weight so loss.backward() does not also compute gradients for the MLP weights
# used in the prediction. The optimizer only zeroes the projections' grads, so
# the model's grads would otherwise keep accumulating across steps.
model.requires_grad_(False)

batch_id = 0
batch_size = model_args.batch_size
layer_id = model_args.layer_id

# Sub-modules of the transformer block whose MLP is being approximated.
block = model.transformer.h[layer_id]
attn_ln = block.ln_1
attn = block.attn
mlp_ln = block.ln_2
mlp = block.mlp
attn_ln.eval()
attn.eval()
mlp_ln.eval()
mlp.eval()

# mlp.c_fc.weight is applied as `x @ weight + bias` on hidden-size inputs, so
# its shape is (hidden, inner). Read both sizes from it rather than assuming
# inner == 4 * hidden.
hidden_size, inner_size = mlp.c_fc.weight.shape

# proj1 @ proj2 (inner -> hidden -> inner) is inserted after c_fc, and
# proj3 @ proj4 (inner -> hidden -> inner) before c_proj. All start at zero
# and are initialised later in the script.
proj1 = nn.Parameter(torch.zeros(inner_size, hidden_size, device=device))
proj2 = nn.Parameter(torch.zeros(hidden_size, inner_size, device=device))

proj3 = nn.Parameter(torch.zeros(inner_size, hidden_size, device=device))
proj4 = nn.Parameter(torch.zeros(hidden_size, inner_size, device=device))

# Use the selected layer's own activation. GPT-2 layers normally share one,
# but this ties the reconstruction to the block actually being approximated.
act = mlp.act
optimizer = torch.optim.SGD([proj1, proj2, proj3, proj4], lr=model_args.learning_rate)
optimizer.zero_grad()


# Initialise each projection pair from the thin SVD of the MLP weight it wraps.
# mlp.c_fc.weight.T and mlp.c_proj.weight both have shape (inner, hidden), so
# the left singular vectors U match proj1/proj3 and U.T matches proj2/proj4.
# U @ U.T is the orthogonal projector onto that weight's column space.
# At initialisation, the only part of the MLP that is changed is therefore the
# c_fc bias, which also gets projected.
with torch.no_grad():
    for left_param, right_param, weight_matrix in (
        (proj1, proj2, mlp.c_fc.weight.T),
        (proj3, proj4, mlp.c_proj.weight),
    ):
        left_singular, _, _ = torch.linalg.svd(weight_matrix, full_matrices=False)
        left_param.copy_(left_singular)
        right_param.copy_(left_singular.T)



# Fit proj1..proj4 so the projected MLP reproduces the original MLP's output on
# PTB training sentences, cycling through the training set in order.
# max(1, ...) avoids a modulo-by-zero when there is less than one full batch.
num_train_batches = max(1, len(train_sequences) // batch_size)
for _ in tqdm(range(model_args.gradient_steps), desc='Gradient descent for new mlp'):
    batch_sentences = tokenizer(
        train_sequences[batch_id * batch_size:(batch_id + 1) * batch_size],
        padding='longest',
        max_length=max_seq_length,
        truncation=True,
        return_tensors='pt',
    )
    batch_input = batch_sentences['input_ids'].to(device)
    batch_mask = batch_sentences['attention_mask'].to(device)

    batch_id = (batch_id + 1) % num_train_batches

    # 1.0 for real tokens, 0.0 for padding. Use the tokenizer's attention mask
    # rather than comparing against pad_token_id: pad == EOS here, so a genuine
    # EOS token would otherwise be dropped from the loss.
    mask = batch_mask.float()

    with torch.no_grad():
        # hidden_states[layer_id] is used as the input to block `layer_id`.
        # Re-run that block's attention sub-layer to get the exact MLP input and
        # the reference MLP output.
        hidden = model(
            batch_input,
            attention_mask=batch_mask,
            output_hidden_states=True,
            use_cache=False,
        ).hidden_states[layer_id]
        input_ = mlp_ln(attn(attn_ln(hidden))[0] + hidden)
        output_ = mlp(input_)

    # The original MLP (c_fc -> act -> c_proj) with the learned projections
    # inserted before the activation and before c_proj.
    inter_prediction = act((input_ @ mlp.c_fc.weight + mlp.c_fc.bias) @ proj1 @ proj2)
    prediction = inter_prediction @ proj3 @ proj4 @ mlp.c_proj.weight + mlp.c_proj.bias

    # Squared error summed over hidden units, averaged over non-padding tokens.
    loss = torch.sum(((output_ - prediction) * mask.unsqueeze(dim=-1)) ** 2) / mask.sum()
    loss.backward()

    print("Train Loss:", loss.item())
    optimizer.step()
    optimizer.zero_grad()
    
    
# Evaluate the fitted projections on the PTB *test* split. Reports the squared
# reconstruction error of the MLP output, averaged over all non-padding tokens.
# Iterating by start offset also covers the final partial batch.
test_loss = 0.
total = 0
for start in tqdm(range(0, len(test_sequences), batch_size)):
    batch_sentences = tokenizer(
        test_sequences[start:start + batch_size],
        padding='longest',
        max_length=max_seq_length,
        truncation=True,
        return_tensors='pt',
    )
    batch_input = batch_sentences['input_ids'].to(device)
    batch_mask = batch_sentences['attention_mask'].to(device)

    # 1.0 for real tokens, 0.0 for padding (pad == EOS, so use the attention mask).
    mask = batch_mask.float()

    with torch.no_grad():
        hidden = model(
            batch_input,
            attention_mask=batch_mask,
            output_hidden_states=True,
            use_cache=False,
        ).hidden_states[layer_id]
        input_ = mlp_ln(attn(attn_ln(hidden))[0] + hidden)
        output_ = mlp(input_)

        inter_prediction = act((input_ @ mlp.c_fc.weight + mlp.c_fc.bias) @ proj1 @ proj2)
        prediction = inter_prediction @ proj3 @ proj4 @ mlp.c_proj.weight + mlp.c_proj.bias

        # Summed (not averaged) per batch; normalised by the total token count below.
        loss = torch.sum(((output_ - prediction) * mask.unsqueeze(dim=-1)) ** 2)

    test_loss += loss.item()
    total += mask.sum().item()

if total > 0:
    print("Average Test loss", test_loss / total)
else:
    print("Average Test loss: no test tokens to evaluate")

# Save the learned projection matrices as numpy arrays, in the order
# [proj1, proj2, proj3, proj4], to projections/projection_<layer_id>.pkl.
# The directory is created if missing, and the file is closed deterministically.
os.makedirs('projections', exist_ok=True)
with open(os.path.join('projections', 'projection_' + str(layer_id) + '.pkl'), 'wb') as handle:
    pickle.dump([p.detach().cpu().numpy() for p in (proj1, proj2, proj3, proj4)], handle)

